Cluster annotation

In [1]:
%load_ext autoreload
%autoreload 2
In [2]:
from scip_workflows.common import *
In [7]:
import pickle
import anndata
import scanpy
import shap
from matplotlib.gridspec import GridSpec
from matplotlib.patches import ConnectionStyle
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_selection import mutual_info_classif
from sklearn.metrics import balanced_accuracy_score
from sklearn.model_selection import train_test_split

from scip_workflows.core import plot_gate_czi

# Initialise shap's notebook JavaScript support so its interactive plots can render.
shap.initjs()
In [30]:
# Raise the default figure resolution for everything drawn in this notebook.
plt.rcParams.update({"figure.dpi": 200})
In [4]:
# Resolve inputs and outputs. When run by Snakemake a global `snakemake` object
# exists; otherwise (interactive use) accessing it raises NameError and we fall
# back to hard-coded development paths.
try:
    adata = snakemake.input.adata
    output_three = snakemake.output[0]
    output_cd15_cd45 = snakemake.output[1]
    output_cd15_siglec8 = snakemake.output[2]
    output_unclassified = snakemake.output[3]
    # Snakemake hands over plain strings; later cells call `image_root.joinpath`
    # and build paths with `data_dir / ...`, so both must be Path objects.
    image_root = Path(snakemake.input.image_root)
    # `data_dir` is needed by the mask-path cell. Derive it from the AnnData
    # location, mirroring the fallback layout (data_dir / "adata.pickle").
    # NOTE(review): assumes masks/ sits next to adata.pickle in workflow runs.
    data_dir = Path(adata).parent
except NameError:
    image_root = Path("/home/maximl/scratch/data/vsc/datasets/cd7/800")
    data_dir = Path("/home/maximl/scratch/data/vsc/datasets/cd7/800/scip/061020221736/")
    adata = data_dir / "adata.pickle"
    output_three = data_dir / "figures" / "cluster_panels.png"
    output_cd15_cd45 = data_dir / "figures" / "cd15_vs_cd45_facets.png"
    output_cd15_siglec8 = data_dir / "figures" / "cd15_vs_siglec8_facets.png"
    output_unclassified = data_dir / "figures" / "unclassified_cluster.png"
In [5]:
def map_names(a):
    """Return the display label for a summed-intensity marker feature.

    Parameters
    ----------
    a : str
        One of ``feat_combined_sum_DAPI``, ``feat_combined_sum_EGFP``,
        ``feat_combined_sum_RPe`` or ``feat_combined_sum_APC``.

    Returns
    -------
    str
        "DAPI", "CD45", "Siglec 8" or "CD15" respectively.

    Raises
    ------
    KeyError
        If ``a`` is not one of the four names above.
    """
    labels = dict(
        feat_combined_sum_DAPI="DAPI",
        feat_combined_sum_EGFP="CD45",
        feat_combined_sum_RPe="Siglec 8",
        feat_combined_sum_APC="CD15",
    )
    return labels[a]
In [8]:
# Replace the input path with the AnnData object pickled at that location.
# NOTE: pickle.load can execute arbitrary code - only load trusted workflow output.
with open(adata, mode="rb") as pickled_file:
    adata = pickle.load(pickled_file)
In [9]:
def _rebase_meta_path(original):
    """Re-root an image path under ``image_root``.

    Keeps every path component after the first one named "800" and joins them
    onto ``image_root``. Raises ValueError if no component is named "800".
    """
    components = Path(original).parts
    anchor = components.index("800")
    return image_root.joinpath(*components[anchor + 1 :])


adata.obs["meta_path"] = adata.obs["meta_path"].apply(_rebase_meta_path)
In [10]:
# Summed-intensity features of the four marker channels (labelled DAPI, CD45,
# Siglec 8 and CD15 by map_names), kept in `var_names` order.
_marker_prefixes = tuple(
    "feat_combined_sum_%s" % channel for channel in ("EGFP", "RPe", "APC", "DAPI")
)
markers = [name for name in adata.var_names if name.startswith(_marker_prefixes)]
In [11]:
# Exploratory view of the raw Leiden clusters: marker matrix plot (left),
# UMAP (middle) and cluster sizes per replicate (right).
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
ax = scanpy.pl.matrixplot(
    adata,
    markers,
    groupby="leiden",
    use_raw=False,
    dendrogram=True,
    vmin=-2,
    vmax=2,
    cmap="RdBu_r",
    show=False,
    ax=axes[0],
)
# Show marker names instead of raw feature names along the matrix's x axis.
heatmap_ax = ax["mainplot_ax"]
heatmap_ax.set_xticklabels(
    [map_names(tick.get_text()) for tick in heatmap_ax.get_xticklabels()]
)
scanpy.pl.umap(adata, color="leiden", legend_loc="on data", ax=axes[1], show=False)
seaborn.countplot(data=adata.obs, x="leiden", hue="meta_replicate", ax=axes[2])
WARNING: dendrogram data not found (using key=dendrogram_leiden). Running `sc.tl.dendrogram` with default parameters. For fine tuning it is recommended to run `sc.tl.dendrogram` independently.
Out[11]:
<AxesSubplot:xlabel='leiden', ylabel='count'>
In [12]:
# Keep Leiden clusters 2, 4, 6 and 8 unchanged and fold every other cluster
# into a single group labelled "1".
_unmerged_clusters = {"2", "4", "6", "8"}
adata.obs["leiden_merged"] = adata.obs.leiden.map(
    lambda cluster: cluster if cluster in _unmerged_clusters else "1"
)
In [13]:
# Same three panels for the merged clusters, arranged as counts per replicate
# (left), marker matrix plot (middle) and UMAP (right).
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
count_ax, marker_ax, umap_ax = axes
ax = scanpy.pl.matrixplot(
    adata,
    markers,
    groupby="leiden_merged",
    use_raw=False,
    dendrogram=True,
    vmin=-2,
    vmax=2,
    cmap="RdBu_r",
    show=False,
    ax=marker_ax,
)
# Show marker names instead of raw feature names along the matrix's x axis.
heatmap_ax = ax["mainplot_ax"]
heatmap_ax.set_xticklabels(
    [map_names(tick.get_text()) for tick in heatmap_ax.get_xticklabels()]
)
scanpy.pl.umap(adata, color="leiden_merged", ax=umap_ax, show=False)
seaborn.countplot(data=adata.obs, x="leiden_merged", hue="meta_replicate", ax=count_ax)
WARNING: dendrogram data not found (using key=dendrogram_leiden_merged). Running `sc.tl.dendrogram` with default parameters. For fine tuning it is recommended to run `sc.tl.dendrogram` independently.
Out[13]:
<AxesSubplot:xlabel='leiden_merged', ylabel='count'>
In [14]:
# CD45 (EGFP) vs CD15 (APC) intensity for all cells, coloured by merged cluster.
scanpy.pl.scatter(
    adata,
    color="leiden_merged",
    legend_loc="on data",
    x="feat_combined_sum_EGFP",
    y="feat_combined_sum_APC",
)
In [15]:
# CD45 (EGFP) vs CD15 (APC), one facet per merged cluster. Each facet overlays
# that cluster's cells on a grey backdrop of all cells; the grid is saved to
# `output_cd15_cd45`.
cd45_cd15 = scanpy.get.obs_df(
    adata,
    keys=["feat_combined_sum_EGFP", "feat_combined_sum_APC", "leiden_merged"],
    use_raw=True,
)
grid = seaborn.FacetGrid(data=cd45_cd15, col="leiden_merged")
for ax in grid.axes.ravel():
    # The backdrop is identical for every facet, so reuse the frame extracted
    # once above instead of re-extracting it from `adata` per axis.
    seaborn.scatterplot(
        data=cd45_cd15,
        x="feat_combined_sum_EGFP",
        y="feat_combined_sum_APC",
        color="grey",
        s=0.5,
        alpha=0.5,
        ax=ax,
    )
grid.map_dataframe(
    seaborn.scatterplot, x="feat_combined_sum_EGFP", y="feat_combined_sum_APC", s=1.5
)
# Titles are set after all plotting so no later grid step can replace them.
grid.set_titles(col_template="Cluster {col_name}")
for ax in grid.axes.ravel():
    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_xlabel("CD45")
    ax.set_ylabel("CD15")

# Save the grid's own figure rather than whatever pyplot considers "current".
grid.savefig(output_cd15_cd45, bbox_inches="tight", pad_inches=0, dpi=200)
In [16]:
# Siglec 8 (RPe) vs CD15 (APC) for Leiden clusters 1, 6, 8 and 9 only,
# coloured by the original Leiden label.
in_selected_leiden = adata.obs.leiden.isin(["1", "6", "8", "9"])
scanpy.pl.scatter(
    adata[in_selected_leiden],
    color="leiden",
    legend_loc="on data",
    x="feat_combined_sum_RPe",
    y="feat_combined_sum_APC",
)
In [17]:
# Siglec 8 (RPe) vs CD15 (APC) for Leiden clusters 1, 6, 8 and 9, one facet per
# merged cluster. Each facet overlays its cells on a grey backdrop of all
# selected cells; the grid is saved to `output_cd15_siglec8`.
siglec8_cd15 = scanpy.get.obs_df(
    adata[adata.obs.leiden.isin(["1", "6", "8", "9"])],
    keys=["feat_combined_sum_RPe", "feat_combined_sum_APC", "leiden_merged"],
    use_raw=True,
)
grid = seaborn.FacetGrid(data=siglec8_cd15, col="leiden_merged")
for ax in grid.axes.ravel():
    # The backdrop is identical for every facet: reuse the frame extracted once
    # above instead of re-subsetting `adata` and re-extracting it per axis.
    seaborn.scatterplot(
        data=siglec8_cd15,
        x="feat_combined_sum_RPe",
        y="feat_combined_sum_APC",
        color="grey",
        s=0.5,
        alpha=0.5,
        ax=ax,
    )
grid.map_dataframe(
    seaborn.scatterplot, x="feat_combined_sum_RPe", y="feat_combined_sum_APC", s=1.5
)
# Titles are set after all plotting so no later grid step can replace them.
grid.set_titles(col_template="Cluster {col_name}")
for ax in grid.axes.ravel():
    ax.set_yticks([])
    ax.set_xticks([])
    ax.set_xlabel("Siglec 8")
    ax.set_ylabel("CD15")

# Save the grid's own figure rather than whatever pyplot considers "current".
grid.savefig(output_cd15_siglec8, bbox_inches="tight", pad_inches=0, dpi=200)

SHAP feature attribution

In [ ]:
# Hold out 10% of cells, stratified by merged cluster, to evaluate the classifier.
# A fixed random_state makes the split reproducible, and with it the accuracy
# and SHAP results below. The seed matches the classifier's random_state=0.
# NOTE(review): `selected_corr` is presumably a boolean feature mask in `adata.var`.
X_train, X_test, y_train, y_test = train_test_split(
    adata[:, adata.var.selected_corr],
    adata.obs["leiden_merged"],
    test_size=0.1,
    stratify=adata.obs["leiden_merged"],
    random_state=0,
)
In [ ]:
# Random forest predicting the merged cluster from the selected features.
model = RandomForestClassifier(random_state=0, n_estimators=50)
model.fit(X_train.to_df(), y_train.values)
In [ ]:
# Balanced accuracy on the held-out cells; it weights each cluster equally,
# which matters given how unequal the cluster sizes are.
preds = model.predict(X_test.to_df())
balanced_accuracy_score(y_true=y_test.values, y_pred=preds)
In [ ]:
# SHAP attributions of the random forest for every held-out cell.
explainer = shap.TreeExplainer(model=model)
shap_values = explainer(X_test.to_df())
In [ ]:
# Category labels of the training target.
y_train.dtype.categories
In [ ]:
# Beeswarm of SHAP values for merged cluster "6" (annotated "unclassified" later).
# The class axis of `shap_values` follows `model.classes_`, so look the label up
# instead of hard-coding its position (previously index 3).
shap.plots.beeswarm(shap_values[..., list(model.classes_).index("6")])
In [ ]:
# Per-cell mask file path: <data_dir>/masks/<meta_scene>_<meta_tile>.npy,
# consumed below via masks_path_col="meta_masks".
# The file name is formatted before joining; %-formatting the already-joined
# path would break on any literal "%" inside data_dir.
adata.obs["meta_masks"] = adata.obs[["meta_scene", "meta_tile"]].apply(
    lambda r: str(data_dir / "masks" / f"{r.meta_scene}_{r.meta_tile}.npy"),
    axis=1,
)
In [ ]:
# Plot up to 50 cells of Leiden cluster 6 on channels 0-6, passing the mask
# paths built above.
plot_gate_czi(
    df=adata.obs,
    sel=adata.obs["leiden"] == "6",
    channels=list(range(7)),
    maxn=50,
    masks_path_col="meta_masks",
)
In [ ]:
# Plot up to 50 cells of Leiden cluster 6 again, this time without mask paths.
# Exploratory only: the figure is not saved here, because `output_unclassified`
# is written by the final "unclassified" cell further down, which would
# overwrite anything saved at this point.
plot_gate_czi(
    sel=adata.obs["leiden"] == "6",
    df=adata.obs,
    channels=[0, 1, 2, 3, 4, 5, 6],
    maxn=50,
)
In [ ]:
# 5th/95th percentile of each summed-intensity feature; `extent` holds one
# (low, high) row per channel, in the order listed below (presumably matching
# channels 0-6 passed to plot_gate_czi - TODO confirm).
quantiles = adata.to_df().filter(regex="feat_combined_sum").quantile([0.05, 0.95])
extent_channels = ["DAPI", "EGFP", "RPe", "APC", "Bright", "Oblique", "PGC"]
extent_columns = ["feat_combined_sum_%s" % channel for channel in extent_channels]
extent = quantiles.loc[:, extent_columns].T.values
In [ ]:
# Cluster 6 again, with the per-channel percentile extents computed above.
plot_gate_czi(
    df=adata.obs,
    sel=adata.obs["leiden"] == "6",
    channels=list(range(7)),
    extent=extent,
    maxn=50,
)
In [ ]:
# CD15 (APC) intensity distribution per merged cluster.
scanpy.pl.violin(adata, keys="feat_combined_sum_APC", groupby="leiden_merged")
In [ ]:
# SHAP dependence of the CD15 (APC) feature for merged cluster "8" (annotated
# "eosinophils" later). The class axis follows `model.classes_`, so look the
# label up instead of hard-coding its position (previously index 4).
shap.plots.scatter(
    shap_values[..., "feat_combined_sum_APC", list(model.classes_).index("8")]
)
In [ ]:
# NOTE(review): `leiden_merged` takes at most five values ("1", "2", "4", "6",
# "8"), so the class axis of `shap_values` has indices 0-4 and index 5 looks out
# of range - this line would raise IndexError on a full run. Possibly left over
# from an earlier clustering; TODO select the intended class by label, e.g.
# `list(model.classes_).index(...)`.
shap.plots.beeswarm(shap_values[..., 5])
In [ ]:
# Leiden cluster 9 (folded into merged cluster "1" by the leiden_merged
# mapping): plot up to 30 of its cells with their mask paths.
plot_gate_czi(
    df=adata.obs,
    sel=adata.obs["leiden"] == "9",
    channels=list(range(7)),
    maxn=30,
    masks_path_col="meta_masks",
)

Cluster annotation

In [18]:
# Map each merged cluster to a cell-type label; every value produced by the
# leiden_merged mapping ("1", "2", "4", "6", "8") has an entry.
cluster2annotation = {
    "1": "granulocytes",
    "8": "eosinophils",
    "4": "monocytes",
    "2": "lymphocytes",
    "6": "unclassified",
}

# Store the labels in a new `cell type` column as an ordered categorical, so
# grouped plots follow this order rather than an alphabetical one.
cell_type_order = [
    "monocytes",
    "lymphocytes",
    "granulocytes",
    "eosinophils",
    "unclassified",
]
cat_type = pandas.CategoricalDtype(categories=cell_type_order, ordered=True)
annotations = adata.obs["leiden_merged"].map(cluster2annotation)
adata.obs["cell type"] = annotations.astype(cat_type)
In [20]:
# Final three-panel summary saved to `output_three`: cell-type counts per
# replicate, mean marker intensity per cell type, and UMAP by cell type.
fig, axes = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True)
count_ax, marker_ax, umap_ax = axes
ax = scanpy.pl.matrixplot(
    adata,
    markers,
    groupby="cell type",
    use_raw=False,
    dendrogram=False,
    vmin=-2,
    vmax=2,
    cmap="RdBu_r",
    show=False,
    ax=marker_ax,
)
# Show marker names instead of raw feature names along the matrix's x axis.
heatmap_ax = ax["mainplot_ax"]
heatmap_ax.set_xticklabels(
    [map_names(tick.get_text()) for tick in heatmap_ax.get_xticklabels()]
)
scanpy.pl.umap(adata, color="cell type", palette="tab10", ax=umap_ax, show=False)
seaborn.countplot(data=adata.obs, y="cell type", hue="meta_replicate", ax=count_ax)

for panel, title in zip(axes, ("Cell type counts", "Marker intensity", "UMAP")):
    panel.set_title(title)
count_ax.legend(title="Replicate")

plt.savefig(output_three, bbox_inches="tight", pad_inches=0, dpi=200)
In [27]:
# Cell-type composition (absolute count and fraction of all cells), printed as
# a LaTeX table. Built explicitly rather than via `value_counts().to_frame()`:
# that column is named "cell type" before pandas 2.0 but "count" from 2.0 on,
# so looking it up by name breaks across versions.
cell_type_sizes = adata.obs["cell type"].value_counts()
counts = pandas.DataFrame(
    {
        "Count": cell_type_sizes,
        "Fraction": cell_type_sizes / cell_type_sizes.sum(),
    }
)
# Keep the index unnamed so the LaTeX header matches the original layout.
counts.index.name = None
print(counts.style.to_latex(hrules=True))
\begin{tabular}{lrr}
\toprule
{} & {Count} & {Fraction} \\
\midrule
granulocytes & 21725 & 0.730989 \\
lymphocytes & 4904 & 0.165007 \\
monocytes & 1737 & 0.058445 \\
unclassified & 1031 & 0.034690 \\
eosinophils & 323 & 0.010868 \\
\bottomrule
\end{tabular}

In [23]:
# 5th/95th percentile of every summed-intensity feature.
# NOTE(review): no later cell shown here reads `quantiles`.
quantiles = (
    adata.to_df()
    .filter(regex="feat_combined_sum")
    .quantile([0.05, 0.95])
)
In [33]:
# Up to 40 cells of the final "unclassified" cell type on channels 0-6,
# saved as the `output_unclassified` figure.
plot_gate_czi(
    df=adata.obs,
    sel=adata.obs["cell type"] == "unclassified",
    channels=list(range(7)),
    maxn=40,
)
plt.savefig(output_unclassified, bbox_inches="tight")
0 P2-D2 0 P3-D1 0 P3-D5 0 P4-D3 0 P4-D5 0 P5-D4 0 P6-D4 0 P9-D2 0 P9-D5 0 P10-D1 0 P10-D3 0 P12-D1 0 P12-D4 0 P13-D1 0 P13-D3 0 P13-D5 0 P14-D1 0 P14-D5 0 P15-D1 0 P15-D4 0 P17-D2 0 P18-D2 0 P19-D3 0 P19-D4 0 P20-D3 0 P20-D5 0 P21-D3 0 P22-D1 0 P22-D4 0 P22-D5 0 P23-D3 0 P23-D4 0 P24-D1 0 P24-D3 0 P24-D4 
In [ ]: